{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concurrent.futures\n",
    "import shutil\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from scipy.stats import fisher_exact, wilcoxon\n",
    "\n",
    "from virtual_lab.agent import Agent\n",
    "from virtual_lab.constants import CONSISTENT_TEMPERATURE, CREATIVE_TEMPERATURE\n",
    "from virtual_lab.prompts import create_merge_prompt\n",
    "from virtual_lab.run_meeting import run_meeting\n",
    "from virtual_lab.utils import load_summaries\n",
    "\n",
    "from nanobody_constants import (\n",
    "    background_prompt,\n",
    "    nanobody_prompt,\n",
    "    discussions_phase_to_dir,\n",
    "    generic_team_lead,\n",
    "    generic_team,\n",
    "    principal_investigator,\n",
    "    immunologist,\n",
    "    team_members,\n",
    "    num_rounds,\n",
    "    num_iterations,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the ablation outputs written by this notebook\n",
    "ablations_dir = discussions_phase_to_dir[\"ablations\"]\n",
    "ablations_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Plot styling shared by the figures below\n",
    "figsize = (10, 6)\n",
    "fontsize = 18\n",
    "colors = [\n",
    "    \"#2166AC\",  # Blue\n",
    "    \"#FDDBC7\",  # Light Pink\n",
    "    \"#B2182B\",  # Dark Red\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generic vs specialized agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare a team of generic agents to a team of specialized agents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project specification - output directory\n",
    "project_specification_dir = ablations_dir / \"project_specification\"\n",
    "project_specification_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Project specification - agenda, written one sentence per line; adjacent string\n",
    "# literals are concatenated into a single prompt\n",
    "project_specification_agenda = (\n",
    "    f\"{background_prompt} \"\n",
    "    \"Please create an antibody/nanobody design approach to solve this problem. \"\n",
    "    \"Decide whether you will design antibodies or nanobodies. \"\n",
    "    \"For your choice, decide whether you will design the antibodies/nanobodies de novo or whether you will modify existing antibodies/nanobodies. \"\n",
    "    \"If modifying existing antibodies/nanobodies, please specify which antibodies/nanobodies to start with as good candidates for targeting the newest variant of the SARS-CoV-2 spike protein. \"\n",
    "    \"If designing antibodies/nanobodies de novo, please describe how you will propose antibody/nanobody candidates.\"\n",
    ")\n",
    "\n",
    "# Project specification - agenda questions passed to the meetings alongside the agenda\n",
    "project_specification_questions = (\n",
    "    \"Will you design standard antibodies or nanobodies?\",\n",
    "    \"Will you design antibodies/nanobodies de novo or will you modify existing antibodies/nanobodies (choose only one)?\",\n",
    "    \"If modifying existing antibodies/nanobodies, which precise antibodies/nanobodies will you modify (please list 3-4)?\",\n",
    "    \"If designing antibodies/nanobodies de novo, how exactly will you propose antibody/nanobody candidates?\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the pre-existing five scientist project specification discussions.\n",
    "# dirs_exist_ok=True keeps this cell re-runnable: without it, copytree raises\n",
    "# FileExistsError as soon as the destination directory already exists.\n",
    "shutil.copytree(\n",
    "    discussions_phase_to_dir[\"project_specification\"],\n",
    "    project_specification_dir / \"scientist_team\",\n",
    "    dirs_exist_ok=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project specification - discussion (generic team)\n",
    "with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "    generic_team_futures = [\n",
    "        executor.submit(\n",
    "            run_meeting,\n",
    "            meeting_type=\"team\",\n",
    "            team_lead=generic_team_lead,\n",
    "            team_members=generic_team,\n",
    "            agenda=project_specification_agenda,\n",
    "            agenda_questions=project_specification_questions,\n",
    "            save_dir=project_specification_dir / \"generic_team\",\n",
    "            save_name=f\"discussion_{iteration_num + 1}\",\n",
    "            temperature=CREATIVE_TEMPERATURE,\n",
    "            num_rounds=num_rounds,\n",
    "        )\n",
    "        for iteration_num in range(num_iterations)\n",
    "    ]\n",
    "    concurrent.futures.wait(generic_team_futures)\n",
    "\n",
    "# wait() only blocks until the meetings finish; it does not re-raise errors from\n",
    "# the worker threads. Calling result() surfaces any exception raised inside\n",
    "# run_meeting instead of silently discarding it.\n",
    "for future in generic_team_futures:\n",
    "    future.result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project specification - merge (generic team)\n",
    "# Only discussions 1-5 are merged; the directory can also contain the additional\n",
    "# discussions written by the scaled-up run. The character range [1-5] replaces\n",
    "# [1|2|3|4|5], which is a character class that also matches a literal \"|\".\n",
    "# Sorting makes the order of the merged summaries independent of glob's arbitrary order.\n",
    "generic_team_discussion_paths = sorted(\n",
    "    (project_specification_dir / \"generic_team\").glob(\"discussion_[1-5].json\")\n",
    ")\n",
    "project_specification_summaries = load_summaries(discussion_paths=generic_team_discussion_paths)\n",
    "print(f\"Number of summaries: {len(project_specification_summaries)}\")\n",
    "\n",
    "project_specification_merge_prompt = create_merge_prompt(\n",
    "    agenda=project_specification_agenda,\n",
    "    agenda_questions=project_specification_questions,\n",
    ")\n",
    "\n",
    "run_meeting(\n",
    "    meeting_type=\"individual\",\n",
    "    team_member=generic_team_lead,\n",
    "    summaries=project_specification_summaries,\n",
    "    agenda=project_specification_merge_prompt,\n",
    "    save_dir=project_specification_dir / \"generic_team\",\n",
    "    save_name=\"merged\",\n",
    "    temperature=CONSISTENT_TEMPERATURE,\n",
    "    num_rounds=num_rounds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project specification - discussion (scaled up generic team and scientist team)\n",
    "# Upper bound on the discussion number: discussions num_iterations + 1 through\n",
    "# this number are run for each team\n",
    "num_scaled_up_iterations = 50\n",
    "\n",
    "scaled_up_futures = []\n",
    "\n",
    "for team_lead, team, team_name in [\n",
    "    (generic_team_lead, generic_team, \"generic_team\"),\n",
    "    (principal_investigator, team_members, \"scientist_team\"),\n",
    "]:\n",
    "    with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "        team_futures = [\n",
    "            executor.submit(\n",
    "                run_meeting,\n",
    "                meeting_type=\"team\",\n",
    "                team_lead=team_lead,\n",
    "                team_members=team,\n",
    "                agenda=project_specification_agenda,\n",
    "                agenda_questions=project_specification_questions,\n",
    "                save_dir=project_specification_dir / team_name,\n",
    "                save_name=f\"discussion_{iteration_num + 1}\",\n",
    "                temperature=CREATIVE_TEMPERATURE,\n",
    "                num_rounds=num_rounds,\n",
    "            )\n",
    "            for iteration_num in range(num_iterations, num_scaled_up_iterations)\n",
    "        ]\n",
    "        concurrent.futures.wait(team_futures)\n",
    "    scaled_up_futures.extend(team_futures)\n",
    "\n",
    "# wait() does not re-raise errors from the worker threads, so check every future\n",
    "# explicitly. This happens after both teams have run so that a failed meeting in\n",
    "# one team does not prevent the other team's meetings from running.\n",
    "for future in scaled_up_futures:\n",
    "    future.result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load design choices\n",
    "# NOTE(review): no cell in this notebook writes design_choices.csv - presumably it is\n",
    "# compiled from the discussions outside the notebook; confirm its provenance.\n",
    "design_choices_path = project_specification_dir / \"design_choices.csv\"\n",
    "design_choices = pd.read_csv(design_choices_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze statistical significance of the difference between the generic and scientist\n",
    "# teams for both design decisions: antibody vs nanobody, and existing vs de novo\n",
    "def build_contingency_table(generic_counts, scientist_counts, category):\n",
    "    \"\"\"Return a 2x2 table of [count of ``category``, count of everything else].\n",
    "\n",
    "    The first row is built from ``generic_counts`` and the second from\n",
    "    ``scientist_counts``; both are expected to be ``value_counts()`` results.\n",
    "    A category that never occurs is counted as 0.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        [counts.get(category, 0), counts.sum() - counts.get(category, 0)]\n",
    "        for counts in (generic_counts, scientist_counts)\n",
    "    ]\n",
    "\n",
    "\n",
    "team_index = [\"Generic Team\", \"Scientist Team\"]\n",
    "\n",
    "generic_protein_counts = design_choices[\"Generic Team: Antibody or Nanobody\"].value_counts()\n",
    "generic_design_counts = design_choices[\"Generic Team: New or Existing\"].value_counts()\n",
    "\n",
    "scientist_protein_counts = design_choices[\"Scientist Team: Antibody or Nanobody\"].value_counts()\n",
    "scientist_design_counts = design_choices[\"Scientist Team: New or Existing\"].value_counts()\n",
    "\n",
    "# Antibody vs nanobody\n",
    "protein_table = build_contingency_table(generic_protein_counts, scientist_protein_counts, \"Nanobody\")\n",
    "\n",
    "print(pd.DataFrame(protein_table, index=team_index, columns=[\"Nanobody\", \"Antibody\"]))\n",
    "print()\n",
    "\n",
    "protein_oddsratio, protein_p_value = fisher_exact(protein_table)\n",
    "print(\"Fisher's Exact Test Results:\")\n",
    "print(f\"Odds Ratio: {protein_oddsratio:.3f}, p-value: {protein_p_value:.3f}\")\n",
    "print()\n",
    "\n",
    "# Existing vs everything else (all other answers are pooled into the second column)\n",
    "design_table = build_contingency_table(generic_design_counts, scientist_design_counts, \"Existing\")\n",
    "\n",
    "print(pd.DataFrame(design_table, index=team_index, columns=[\"Existing\", \"Both or New\"]))\n",
    "print()\n",
    "\n",
    "design_oddsratio, design_p_value = fisher_exact(design_table)\n",
    "print(\"Fisher's Exact Test Results:\")\n",
    "print(f\"Odds Ratio: {design_oddsratio:.3f}, p-value: {design_p_value:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def annotate_significance(ax, p_value, x1, x2, y):\n",
    "    \"\"\"Add a statistical significance bracket between two x positions on ``ax``.\n",
    "\n",
    "    The label is ``***`` for p < 0.001, ``**`` for p < 0.01, ``*`` for p < 0.05,\n",
    "    and ``ns`` otherwise. The horizontal bar is drawn at ``y - 2``, the end ticks\n",
    "    run from ``y - 3`` to ``y - 2``, and the label is anchored at ``y - 1``.\n",
    "    \"\"\"\n",
    "    # Thresholds in ascending order: the first one that p_value falls below wins\n",
    "    significance = 'ns'  # Not significant\n",
    "    for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):\n",
    "        if p_value < threshold:\n",
    "            significance = stars\n",
    "            break\n",
    "\n",
    "    bar_y = y - 2\n",
    "    tick_bottom_y = y - 3\n",
    "    line_style = dict(color='black', linewidth=1)\n",
    "\n",
    "    # Horizontal bar\n",
    "    ax.plot([x1, x2], [bar_y, bar_y], **line_style)\n",
    "    # Vertical end ticks\n",
    "    ax.plot([x1, x1], [tick_bottom_y, bar_y], **line_style)\n",
    "    ax.plot([x2, x2], [tick_bottom_y, bar_y], **line_style)\n",
    "    # Significance marker centered over the bar\n",
    "    ax.text((x1 + x2) / 2, y - 1, significance, ha='center', va='bottom', fontsize=fontsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot protein choices\n",
    "team_labels = [\"Generic Team\", \"Scientist Team\"]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "# Stack one bar segment per protein type, tracking the running height for each team\n",
    "segment_bottoms = [0, 0]\n",
    "for protein_type, segment_color in [(\"Nanobody\", colors[0]), (\"Antibody\", colors[2])]:\n",
    "    segment_heights = [\n",
    "        generic_protein_counts.get(protein_type, 0),\n",
    "        scientist_protein_counts.get(protein_type, 0),\n",
    "    ]\n",
    "    ax.bar(team_labels, segment_heights, bottom=segment_bottoms, label=protein_type, color=segment_color)\n",
    "    segment_bottoms = [bottom + height for bottom, height in zip(segment_bottoms, segment_heights)]\n",
    "\n",
    "annotate_significance(ax, protein_p_value, 0, 1, 55)\n",
    "\n",
    "ax.set_ylabel(\"Count\", fontsize=fontsize)\n",
    "ax.tick_params(axis=\"both\", labelsize=fontsize)\n",
    "ax.legend(fontsize=fontsize)\n",
    "ax.set_ylim(0, 58)\n",
    "fig.savefig(project_specification_dir / \"protein_counts.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot design choices\n",
    "team_labels = [\"Generic Team\", \"Scientist Team\"]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "# Stack one bar segment per design choice, tracking the running height for each team\n",
    "segment_bottoms = [0, 0]\n",
    "for design_choice, segment_color in zip([\"Existing\", \"Both\", \"New\"], colors):\n",
    "    segment_heights = [\n",
    "        generic_design_counts.get(design_choice, 0),\n",
    "        scientist_design_counts.get(design_choice, 0),\n",
    "    ]\n",
    "    ax.bar(team_labels, segment_heights, bottom=segment_bottoms, label=design_choice, color=segment_color)\n",
    "    segment_bottoms = [bottom + height for bottom, height in zip(segment_bottoms, segment_heights)]\n",
    "\n",
    "annotate_significance(ax, design_p_value, 0, 1, 55)\n",
    "\n",
    "ax.set_ylabel(\"Count\", fontsize=fontsize)\n",
    "ax.tick_params(axis=\"both\", labelsize=fontsize)\n",
    "ax.legend(fontsize=fontsize)\n",
    "ax.set_ylim(0, 58)\n",
    "fig.savefig(project_specification_dir / \"design_counts.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specialized vs finetuned agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question and Answer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See `finetune_agents.ipynb` for the question and answer pairs and evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Knowledge"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the knowledge of a specialized agent (Immunologist) and a finetuned agent (Immunologist-FT)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finetuned counterpart of the Immunologist: same title, expertise, goal, and role,\n",
    "# but backed by the finetuned model\n",
    "finetuned_model = \"ft:gpt-4o-mini-2024-07-18:personal:sars-cov-2-variants-kp-3-and-jn-1:AzeSvPAe\"\n",
    "\n",
    "immunologist_ft = Agent(\n",
    "    **{field: getattr(immunologist, field) for field in (\"title\", \"expertise\", \"goal\", \"role\")},\n",
    "    model=finetuned_model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the specialized and the finetuned agent the same question in individual meetings\n",
    "agenda = \"What are some of the RBD mutations of the KP.3 and JN.1 variants of the SARS-CoV-2 spike protein?\"\n",
    "\n",
    "agents_names = [\n",
    "    (immunologist, \"immunologist\"),\n",
    "    (immunologist_ft, \"immunologist_ft\"),\n",
    "]\n",
    "\n",
    "# Settings that are identical for both meetings\n",
    "knowledge_meeting_kwargs = dict(\n",
    "    meeting_type=\"individual\",\n",
    "    agenda=agenda,\n",
    "    save_dir=ablations_dir,\n",
    "    temperature=CONSISTENT_TEMPERATURE,\n",
    ")\n",
    "\n",
    "for agent, agent_name in agents_names:\n",
    "    run_meeting(team_member=agent, save_name=f\"knowledge_{agent_name}\", **knowledge_meeting_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single vs multi-agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at `discussions/alphafold/discussion_4.md` to compare the effect of a single agent vs a multi-agent interaction. The first round answer is the product of just a single specialized agent (Computational Biologist) while the final round answer is the product of the multi-agent interaction (Computational Biologist and Scientific Critic).\n",
    "\n",
    "Also, see below for a human analysis of the effects of single agent vs multi-agent interactions for coding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the human analysis of single agent vs multi-agent interactions for coding\n",
    "coding_eval_path = ablations_dir / \"coding_eval.csv\"\n",
    "coding_eval = pd.read_csv(coding_eval_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-sided Wilcoxon signed-rank test on the \"# Enhancements\" column\n",
    "# (alternative=\"greater\"). The first return value is the test statistic,\n",
    "# not an odds ratio, so it is named accordingly.\n",
    "wilcoxon_statistic, p_value = wilcoxon(coding_eval[\"# Enhancements\"], alternative=\"greater\")\n",
    "print(\"Wilcoxon Signed-Rank Test Results:\")\n",
    "print(f\"Statistic: {wilcoxon_statistic:.3f}, p-value: {p_value:.3e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swarm plot of \"# Enhancements\" per topic, with a reference line at y=0.\n",
    "# seaborn (sns) is imported in the imports cell at the top of the notebook.\n",
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "ax.axhline(y=0, color='black', linestyle='-', linewidth=2)\n",
    "sns.swarmplot(data=coding_eval, x='Topic', y='# Enhancements', size=18, ax=ax)\n",
    "ax.set_xlabel(\"\")\n",
    "ax.set_ylabel(\"# Enhancements\", fontsize=fontsize)\n",
    "ax.tick_params(axis=\"both\", labelsize=fontsize)\n",
    "fig.savefig(ablations_dir / \"coding_eval.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single vs parallel meetings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the effect of parallel meetings, compare the individual discussions (e.g., `discussion_1.md` through `discussion_5.md`) with the result of the parallel meeting (e.g., `merged.md`) for any of the discussions in the `discussions` directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompt engineering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Illustrate the effect of prompt engineering on tool selection. For the answer using the original agenda and agenda questions, see `discussions/tools_selection/merged.md`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tools selection - setup\n",
    "tools_selection_dir = ablations_dir / \"tools_selection\"\n",
    "\n",
    "# The merged project specification discussion is the only prior summary that is loaded\n",
    "project_specification_merged_path = discussions_phase_to_dir[\"project_specification\"] / \"merged.json\"\n",
    "tools_selection_prior_summaries = load_summaries(discussion_paths=[project_specification_merged_path])\n",
    "print(f\"Number of prior summaries: {len(tools_selection_prior_summaries)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tools selection - agendas and questions\n",
    "# Each entry pairs an agenda with its agenda questions. Every variant removes one\n",
    "# more piece of guidance than the variant before it.\n",
    "tools_selection_prompts = [\n",
    "    # Variant 1: names \"machine learning and/or computational\" tools, asks for 5-10\n",
    "    # of them, and asks how each could be used\n",
    "    (\n",
    "        f\"{background_prompt} {nanobody_prompt} Now you need to select machine learning and/or computational tools to implement this nanobody design approach. Please list several tools (5-10) that would be relevant to this nanobody design approach and how they could be used in the context of this project.\",\n",
    "        (\n",
    "            \"What machine learning and/or computational tools could be used for this nanobody design approach (list 5-10)?\",\n",
    "            \"For each tool, how could it be used for designing modified nanobodies?\",\n",
    "        ),\n",
    "    ),\n",
    "    # Variant 2: no longer names the kind of tools\n",
    "    (\n",
    "        f\"{background_prompt} {nanobody_prompt} Now you need to select tools to implement this nanobody design approach. Please list several tools (5-10) that would be relevant to this nanobody design approach and how they could be used in the context of this project.\",\n",
    "        (\n",
    "            \"What tools could be used for this nanobody design approach (list 5-10)?\",\n",
    "            \"For each tool, how could it be used for designing modified nanobodies?\",\n",
    "        ),\n",
    "    ),\n",
    "    # Variant 3: no longer asks for a specific number of tools\n",
    "    (\n",
    "        f\"{background_prompt} {nanobody_prompt} Now you need to select tools to implement this nanobody design approach. Please list several tools that would be relevant to this nanobody design approach and how they could be used in the context of this project.\",\n",
    "        (\n",
    "            \"What tools could be used for this nanobody design approach?\",\n",
    "            \"For each tool, how could it be used for designing modified nanobodies?\",\n",
    "        ),\n",
    "    ),\n",
    "    # Variant 4: no longer asks how each tool could be used\n",
    "    (\n",
    "        f\"{background_prompt} {nanobody_prompt} Now you need to select tools to implement this nanobody design approach. Please list several tools that would be relevant to this nanobody design approach.\",\n",
    "        (\n",
    "            \"What tools could be used for this nanobody design approach?\",\n",
    "        ),\n",
    "    ),\n",
    "]\n",
    "\n",
    "# Split the pairs back into two parallel lists\n",
    "tools_selection_agendas = [prompt_agenda for prompt_agenda, _ in tools_selection_prompts]\n",
    "tools_selection_questions = [prompt_questions for _, prompt_questions in tools_selection_prompts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tools selection - ablations\n",
    "# Zero-based indices of the prompt variants to run. Only the last variant (saved\n",
    "# under prompt_4) is enabled; add 0, 1, or 2 to also run the earlier variants.\n",
    "# NOTE(review): confirm whether restricting the run to the last variant is intentional.\n",
    "prompt_indices_to_run = {3}\n",
    "\n",
    "for i, (agenda, agenda_questions) in enumerate(zip(tools_selection_agendas, tools_selection_questions)):\n",
    "    if i not in prompt_indices_to_run:\n",
    "        continue\n",
    "\n",
    "    # Create save directory\n",
    "    save_dir = tools_selection_dir / f\"prompt_{i + 1}\"\n",
    "    save_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Run meetings\n",
    "    with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "        meeting_futures = [\n",
    "            executor.submit(\n",
    "                run_meeting,\n",
    "                meeting_type=\"team\",\n",
    "                team_lead=principal_investigator,\n",
    "                team_members=team_members,\n",
    "                summaries=tools_selection_prior_summaries,\n",
    "                agenda=agenda,\n",
    "                agenda_questions=agenda_questions,\n",
    "                save_dir=save_dir,\n",
    "                save_name=f\"discussion_{iteration_num + 1}\",\n",
    "                temperature=CREATIVE_TEMPERATURE,\n",
    "                num_rounds=num_rounds,\n",
    "            )\n",
    "            for iteration_num in range(num_iterations)\n",
    "        ]\n",
    "        concurrent.futures.wait(meeting_futures)\n",
    "\n",
    "    # wait() does not re-raise errors from the worker threads; result() surfaces a\n",
    "    # failed meeting here instead of letting the merge run on fewer summaries\n",
    "    for future in meeting_futures:\n",
    "        future.result()\n",
    "\n",
    "    # Load summaries (sorted so the order passed to the merge does not depend on\n",
    "    # the arbitrary order in which glob yields files)\n",
    "    summaries = load_summaries(discussion_paths=sorted(save_dir.glob(\"discussion_*.json\")))\n",
    "    print(f\"Number of summaries: {len(summaries)}\")\n",
    "\n",
    "    # Create merge prompt\n",
    "    merge_prompt = create_merge_prompt(agenda=agenda, agenda_questions=agenda_questions)\n",
    "\n",
    "    # Run merge meeting\n",
    "    run_meeting(\n",
    "        meeting_type=\"individual\",\n",
    "        team_member=principal_investigator,\n",
    "        summaries=summaries,\n",
    "        agenda=merge_prompt,\n",
    "        save_dir=save_dir,\n",
    "        save_name=\"merged\",\n",
    "        temperature=CONSISTENT_TEMPERATURE,\n",
    "        num_rounds=num_rounds,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "virtual_lab",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
